import pickle
import random
import numpy as np


def fill_in(mylist, indices, vals):
    """Write ``vals`` into ``mylist`` at the positions in ``indices``, in place.

    The i-th value goes to position ``indices[i]``. ``mylist`` is mutated and
    nothing is returned.

    Args:
        mylist: Mutable indexable container to write into (e.g. a list).
        indices: Sized sequence of positions in ``mylist``.
        vals: Sized sequence of values, same length as ``indices``.

    Raises:
        ValueError: If ``indices`` and ``vals`` differ in length.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and a mismatch would then silently leave slots unfilled.
    if len(indices) != len(vals):
        raise ValueError(
            f"indices and vals must have the same length "
            f"(got {len(indices)} and {len(vals)})"
        )
    for ind, val in zip(indices, vals):
        mylist[ind] = val


def generate_vector_group(vector, start, threshold=0.001, V=5000, K=20):
    '''
    Return a group of 4 vectors built by permuting the entries of ``vector``,
    where vector 1 ~ vector 2 != vector 3, vector 4, and vector 3 != vector 4.

    The entries of ``vector`` are sorted in descending order and split into
    three bands, with ``V = len(vector)``, ``n_top = int(V / K)`` and
    ``width = int(4 * V / K)``:
      * top:  the ``n_top`` largest values,
      * mid:  the next ``width - n_top`` values,
      * rest: all remaining values.
    In every returned vector the top and mid values occupy the window
    ``[start, start + width)`` and the (shuffled) rest values fill every
    position outside it.

    * Vectors 1 and 2 ("similar"): the top values, kept in descending order,
      go to ``n_top`` randomly chosen (then sorted) positions among the first
      ``int(1.2 * V / K)`` slots of the window; the mid values fill the
      remaining window slots in descending order.
    * Vectors 3 and 4 ("different"): the top and mid values are shuffled
      together across the whole window.

    Randomness comes from both ``np.random`` (position choice) and the
    ``random`` module (shuffles); both must be seeded for reproducibility.

    Args:
        vector: 1-D sequence of numeric values to permute.
        start: First index of the window. Must satisfy
            ``0 <= start`` and ``start + int(4 * V / K) <= len(vector)``.
        threshold: Unused; kept for backward compatibility.
        V: Ignored (overwritten with ``len(vector)``); kept for backward
            compatibility.
        K: Controls the band sizes as described above.

    Returns:
        Tuple of 4 Python lists, each of length ``len(vector)``.

    Raises:
        ValueError: If the window does not fit inside ``vector``.
    '''
    V = len(vector)
    n_top = int(V / K)
    width = int(4 * V / K)
    if start < 0 or start + width > V:
        raise ValueError(
            f"window [{start}, {start + width}) does not fit in a vector of length {V}"
        )

    # Sort descending and split into the top / mid / rest bands.
    vector_sorted = sorted(list(vector))[::-1]
    top_vals = vector_sorted[:n_top]
    mid_vals = vector_sorted[n_top:width]
    rest = vector_sorted[width:]

    window = range(start, start + width)
    # Positions outside the window always receive the "rest" band.
    outside = list(range(0, start)) + list(range(start + width, V))

    sim_vecs = []
    for _ in range(2):
        v = [0] * V
        # Top values (descending) go to n_top sorted random slots near the
        # start of the window.
        top_indices = sorted(list(np.random.choice(range(start, start + int(1.2 * V / K)), n_top, replace=False)))
        fill_in(v, top_indices, top_vals)

        random.shuffle(rest)
        fill_in(v, outside, rest)

        # Mid values (descending) fill the window slots not taken by top values.
        # Track taken slots explicitly instead of testing ``v[i] == 0``: a top
        # value that is exactly 0.0 would otherwise be overwritten and run
        # ``mid_vals`` past its end.
        taken = {int(i) for i in top_indices}
        free_slots = [i for i in window if i not in taken]
        fill_in(v, free_slots, mid_vals)

        sim_vecs.append(v)

    diff_vecs = []
    for _ in range(2):
        v = [0] * V

        # Top and mid values shuffled together across the whole window.
        vals = top_vals + mid_vals
        random.shuffle(vals)
        fill_in(v, list(window), vals)

        random.shuffle(rest)
        fill_in(v, outside, rest)

        diff_vecs.append(v)

    return sim_vecs[0], sim_vecs[1], diff_vecs[0], diff_vecs[1]

def get_matrix_CTM(alpha, V=5000, K=20):
    """Build a (K, V) topic-word matrix with CTM-style grouped structure.

    ``int(K / 4)`` base vectors are drawn from a symmetric Dirichlet with
    per-word concentration ``alpha / K``. Each base vector is expanded by
    ``generate_vector_group`` into four rows, using window start
    ``int(4 * V / K * g)`` for group ``g``. Each group is stored in
    consecutive rows in the order (similar, similar, different, different).

    NOTE(review): rows at index ``4 * int(K / 4)`` and beyond are left all-zero,
    so K is presumably meant to be a multiple of 4 — confirm with callers.
    """
    n_groups = int(K / 4)
    base_vectors = np.random.dirichlet(alpha * np.ones(V) / K, n_groups)
    matrix = np.zeros((K, V))
    for g, base in enumerate(base_vectors):
        group = generate_vector_group(base, start=int(4*V/K*g), V=V, K=K)
        # Rows 4g .. 4g+3 receive v1, v2, v3, v4 in order.
        matrix[4 * g:4 * g + 4] = np.array(group)
    return matrix

def get_matrix_PAM(alpha, V=5000, K=20):
    """Build a (K, V) topic-word matrix with PAM-style grouped structure.

    Same sampling as ``get_matrix_CTM``: ``int(K / 4)`` symmetric
    Dirichlet(``alpha / K``) base vectors, each expanded by
    ``generate_vector_group`` with window start ``int(4 * V / K * g)``.
    The four rows of a group are interleaved instead: row 4g gets similar
    vector 1, row 4g+1 gets different vector 3, row 4g+2 gets similar
    vector 2, and row 4g+3 gets different vector 4.

    NOTE(review): rows at index ``4 * int(K / 4)`` and beyond are left all-zero,
    so K is presumably meant to be a multiple of 4 — confirm with callers.
    """
    n_groups = int(K / 4)
    base_vectors = np.random.dirichlet(alpha * np.ones(V) / K, n_groups)
    matrix = np.zeros((K, V))
    for g, base in enumerate(base_vectors):
        sim_a, sim_b, diff_a, diff_b = generate_vector_group(base, start=int(4*V/K*g), V=V, K=K)
        # Interleave so the similar pair sits on rows 4g and 4g+2.
        matrix[4 * g:4 * g + 4] = np.array([sim_a, diff_a, sim_b, diff_b])
    return matrix

def _generate_all_topic_matrices(V, K, suffix, n_alphas=10):
    """Sample and save topic-word matrices for the Pure, LDA, CTM and PAM setups.

    For each alpha in ``1 .. n_alphas``, one (K, V) matrix is produced per
    model. The matrices are stacked into an (n_alphas, K, V) array and saved
    as ``src_<Model><suffix>/TopicMatrices.npy``. Pure and LDA share the same
    matrices: K independent Dirichlet(alpha / K) draws over V words. Output
    directories are created if missing.

    Args:
        V: Vocabulary size.
        K: Number of topics.
        suffix: Appended to each output directory name (e.g. ``'_t=2'``).
        n_alphas: Number of alpha values (1, 2, ..., n_alphas) to sample.
    """
    from pathlib import Path

    def save(model, matrices):
        # np.save would fail with FileNotFoundError on a missing directory.
        out_dir = Path(f'src_{model}{suffix}')
        out_dir.mkdir(parents=True, exist_ok=True)
        np.save(out_dir / 'TopicMatrices.npy', matrices)

    # Pure / LDA: independent symmetric Dirichlet topics.
    res = np.zeros((n_alphas, K, V))
    for i in range(n_alphas):
        alpha = float(i + 1) / K
        res[i, :, :] = np.random.dirichlet(alpha * np.ones(V), K)
    save('Pure', res)
    save('LDA', res)

    # CTM
    ctm_mat = np.zeros((n_alphas, K, V))
    for i in range(n_alphas):
        ctm_mat[i, :, :] = get_matrix_CTM(i + 1, V=V, K=K)
    save('CTM', ctm_mat)

    # PAM
    pam_mat = np.zeros((n_alphas, K, V))
    for i in range(n_alphas):
        pam_mat[i, :, :] = get_matrix_PAM(i + 1, V=V, K=K)
    save('PAM', pam_mat)


if __name__ == '__main__':
    # Main setting: vocab size 5000, 20 topics.
    _generate_all_topic_matrices(V=5000, K=20, suffix='')

    # Setting for t=2: vocab size 500, 8 topics.
    _generate_all_topic_matrices(V=500, K=8, suffix='_t=2')

    print('Successfully generated topic-word matrix for every topic model')

